package breakthrough;

import breakthrough.*;
import boardgame.*;
import java.io.*;
import java.util.*;

/** Training program for the BTNeuralNetPlayer.
 *  Usage:
 *    java breakthrough.NetTrainer <numGames> <learningRate> <layer1units> <layer2units> ...
 *  trains by observing 'numGames' games using the learning
 *  rate 'learningRate'. Default values are 100000 and 0.001
 *  respectively.
 *  
 *  The structure of the net is determined by the
 *  arguments following the learning rate: specify the number
 *  of units in each hidden layer. If no number is specified,
 *  there is no hidden layer: we are training a perceptron.
 *
 *  There are always BTNeuralNetPlayer.NUM_FEATURES input units, 
 *  and a single output unit. We are trying to classify
 *  boards as winning (output=1.0) or losing (output=0.0).
 *
 *  Learning data is generated by watching games between two
 *  BTHeuristicPlayer players.
 *
 *  Two output files are written:
 *  The network is saved in 'network.txt' and the weights
 *  after every 100 training examples are stored in the
 *  file 'train.csv'.
 */
public class NetTrainer {
    // The player whose neural network we are training.
    private BTNeuralNetPlayer player;
    // Destination for the periodic weight dumps.
    private PrintStream logFile;
    private static final String LOG_FILE_NAME = "train.csv";
    // Append the weights to the log after this many games.
    private static final int LOG_INTERVAL = 100;
    // Gradient step size; may be overridden by argv[1].
    private double learningRate = 0.001;

    /** Create a new Trainer object and let it do its work. */
    public static void main(String[] argv) {
        (new NetTrainer()).run(argv);
    }

    /** Main function, does the training as specified in the
     *  arguments: argv[0] = number of games (default 100000),
     *  argv[1] = learning rate (default 0.001), argv[2..] =
     *  hidden-layer sizes (default: none, i.e. a perceptron).
     *  Writes the weight log to 'train.csv' and the trained
     *  network to BTNeuralNetPlayer.NETWORK_FILE.
     */
    public void run(String[] argv) {
        // Open an output file for the weight log.
        try {
            logFile = new PrintStream(new FileOutputStream(LOG_FILE_NAME));
        } catch (Exception e) {
            System.err.println("Failed to create log file '" +
                               LOG_FILE_NAME + "'.");
            e.printStackTrace();
            return;
        }

        // Everything below runs inside try/finally so the log file
        // is closed even if training or argument parsing throws.
        try {
            // Parse arguments; missing arguments fall back to defaults.
            int numGames = argv.length >= 1
                ? Integer.parseInt(argv[0]) : 100000;
            learningRate = argv.length >= 2
                ? Double.parseDouble(argv[1]) : 0.001;

            // Each argument after the learning rate names one hidden
            // layer; the net always has an input layer of NUM_FEATURES
            // units and a single output unit on top of those.
            int numHiddenLayers = Math.max(0, argv.length - 2);
            int[] structure = new int[numHiddenLayers + 2];
            structure[0] = BTNeuralNetPlayer.NUM_FEATURES;
            structure[structure.length - 1] = 1;
            for (int i = 0; i < numHiddenLayers; i++)
                structure[i + 1] = Integer.parseInt(argv[i + 2]);

            // Create the player around a net with that structure.
            player = new BTNeuralNetPlayer(structure);

            // INITIALIZE WEIGHTS
            player.getNeuralNet().randomizeWeights();
            System.out.println("Initial weights:");
            player.getNeuralNet().printWeights(System.out);
            System.out.println("");

            // DO TRAINING
            trainFromHeuristicPlayer(numGames);

            // SAVE WEIGHTS TO FILE
            System.out.println("Final weights:");
            player.getNeuralNet().printWeights(System.out);
            System.out.println("");
            try {
                player.getNeuralNet().save(BTNeuralNetPlayer.NETWORK_FILE);
            } catch (Exception e) {
                System.err.println("Failed to save weights to '" +
                                   BTNeuralNetPlayer.NETWORK_FILE + "'.");
                e.printStackTrace();
            }
        } catch (NumberFormatException e) {
            System.err.println("Arguments must be numeric: " +
                               e.getMessage());
        } finally {
            logFile.close();
        }
    }

    /** Train from watching games between BTHeuristicPlayer's.
     *  @param numGames the number of games to observe
     */
    public void trainFromHeuristicPlayer(int numGames) {
        System.out.print("Training for " + numGames +
                         " games between two BTHeuristicPlayers...");
        // A single stateless heuristic player can take both seats.
        Player pl = new BTHeuristicPlayer();
        train(numGames, pl, pl);
        System.out.println(" Done.");
    }

    /** Simulate a number of games between two players, and learn
     *  from each game.
     *  @param numGames the number of games to play
     *  @param player1 one of the players used to simulate games
     *  @param player2 the other player
     */
    public void train(int numGames, Player player1, Player player2) {
        // Holds the boards encountered during one game; presized to
        // avoid repeated growth (games rarely exceed 250 plies).
        ArrayList boards = new ArrayList(250);

        // Play numGames games.
        for (int game = 0; game < numGames; game++) {
            BTBoard theBoard = new BTBoard();

            // Play one game to completion, snapshotting the board
            // after every move. player1 always moves as WHITE.
            while (theBoard.getWinner() == BTBoard.NOBODY) {
                if (theBoard.getTurn() == BTBoard.WHITE)
                    theBoard.move(player1.chooseMove(theBoard));
                else
                    theBoard.move(player2.chooseMove(theBoard));
                // Clone: theBoard mutates in place on each move.
                boards.add(theBoard.clone());
            }

            // Update the network from this game's positions.
            learnFromGame(boards);

            // Append the current weights to the log periodically,
            // and print a progress dot.
            if ((game + 1) % LOG_INTERVAL == 0) {
                player.getNeuralNet().printWeights(logFile);
                logFile.println("");
                System.out.print('.');
            }

            // Forget this game.
            boards.clear();

            // Swap seats so each player sees both colors equally.
            Player tmp = player1;
            player1 = player2;
            player2 = tmp;
        }
    }

    /** Update weights based on one game. The list contains all the boards
     *  occurring during the game, starting from just after the first move.
     *  @param boards the list of BTBoards
     *  @throws IllegalArgumentException if the final board has no winner
     */
    public void learnFromGame(ArrayList boards) {
        double[] desiredOutput = new double[1];
        // Find out who won the game (the last board is terminal).
        int winner = ((BTBoard) boards.get(boards.size() - 1)).getWinner();
        if (winner != BTBoard.BLACK && winner != BTBoard.WHITE)
            throw new IllegalArgumentException(
                "Final board has no winner: " + winner);
        // Make a training example out of each board. If WHITE just played,
        // then it is an example from the WHITE point of view. Otherwise,
        // it is an example from the BLACK point of view. The desired
        // output is 0.0 if the player loses the game, and 1.0 if she wins.
        for (int i = 0; i < boards.size(); i++) {
            BTBoard b = (BTBoard) boards.get(i);
            // Since the list of boards starts after the first move,
            // the last player to have played will be WHITE if i is even
            // and BLACK if i is odd.
            int lastPlayer = (i % 2 == 0) ? BTBoard.WHITE : BTBoard.BLACK;
            desiredOutput[0] = (winner == lastPlayer) ? 1.0 : 0.0;
            // Get the features from the board, from the mover's view.
            double[] features = player.featurize(b, lastPlayer);
            // Train on the example.
            player.getNeuralNet().train(features,
                                        desiredOutput, learningRate);
        }
    }
}
